{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Kalman Filter\n",
    "\n",
    "Kalman filters are linear models for state estimation of dynamic systems [1].  They have been the <i>de facto</i> standard in many robotics and tracking/prediction applications because they are well suited for systems with uncertainty about an observable dynamic process.  They use an \"observe, predict, correct\" paradigm to extract information from an otherwise noisy signal. In Pyro, we can build differentiable Kalman filters with learnable parameters using the [`pyro.contrib.tracking` library](http://docs.pyro.ai/en/dev/contrib.tracking.html#module-pyro.contrib.tracking.extended_kalman_filter).\n",
    "\n",
    "## Dynamic process\n",
    "\n",
    "To start, consider this simple motion model:\n",
    "\n",
    "$$ X_{k+1} = FX_k + \\mathbf{W}_k $$\n",
    "$$ \\mathbf{Z}_k = HX_k + \\mathbf{V}_k $$\n",
    "\n",
    "where $k$ is the timestep, $X_k$ is the state at timestep $k$, $\\mathbf{Z}_k$ is the observed value at timestep $k$, and $\\mathbf{W}_k$ and $\\mathbf{V}_k$ are independent noise processes (i.e. $\\mathbb{E}[\\mathbf{W}_k \\mathbf{V}_j^T] = 0$ for all $j, k$), which we'll approximate as Gaussians. Note that both the state transition $F$ and the measurement map $H$ are linear."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Kalman Update\n",
    "At each time step, we perform a prediction for the mean and covariance:\n",
    "$$ \\hat{X}_k = F\\hat{X}_{k-1}$$\n",
    "$$\\hat{P}_k = FP_{k-1}F^T + Q$$\n",
    "and a correction for the measurement:\n",
    "$$ K_k = \\hat{P}_k H^T(H\\hat{P}_k H^T + R)^{-1}$$\n",
    "$$ X_k = \\hat{X}_k + K_k(z_k - H\\hat{X}_k)$$\n",
    "$$ P_k = (I-K_k H)\\hat{P}_k$$\n",
    "\n",
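    "where $\\hat{X}_k$ and $\\hat{P}_k$ are the predicted state mean and covariance, $X_k$ and $P_k$ are the corrected state estimate and its covariance, $K_k$ is the Kalman gain, and $Q$ and $R$ are the covariance matrices of the process noise $\\mathbf{W}_k$ and the measurement noise $\\mathbf{V}_k$, respectively.\n",
    "\n",
    "For an in-depth derivation, see \\[1\\]."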
   ]
  },
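  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The prediction and correction equations above can be sketched directly in PyTorch. The cell below is a minimal illustration under stated assumptions, not the implementation used by `pyro.contrib.tracking`: the helper names `kf_predict`, `kf_correct` and `demo_scalar_step`, and all numerical values, are hypothetical and chosen only for this example.\n",
    "\n",
    "With the scalar values used here ($F = H = 1$, $Q = 1$, $R = 2$, previous estimate $0$ with variance $1$, measurement $z = 4$), the predicted variance is $2$, the gain is $0.5$, and the corrected estimate lands halfway between the prediction and the measurement."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def kf_predict(x, P, F, Q):\n",
    "    # propagate the mean and covariance through the linear dynamics\n",
    "    x_hat = F @ x\n",
    "    P_hat = F @ P @ F.t() + Q\n",
    "    return x_hat, P_hat\n",
    "\n",
    "\n",
    "def kf_correct(x_hat, P_hat, z, H, R):\n",
    "    # innovation covariance and Kalman gain\n",
    "    S = H @ P_hat @ H.t() + R\n",
    "    K = P_hat @ H.t() @ torch.inverse(S)\n",
    "    # fold the measurement into the predicted mean and covariance\n",
    "    x = x_hat + K @ (z - H @ x_hat)\n",
    "    P = (torch.eye(P_hat.shape[0]) - K @ H) @ P_hat\n",
    "    return x, P\n",
    "\n",
    "\n",
    "def demo_scalar_step():\n",
    "    # hypothetical scalar system: F = H = 1, Q = 1, R = 2\n",
    "    F = torch.eye(1)\n",
    "    H = torch.eye(1)\n",
    "    Q = torch.eye(1)\n",
    "    R = 2.0 * torch.eye(1)\n",
    "    # previous estimate: mean 0 with unit variance\n",
    "    x_hat, P_hat = kf_predict(torch.zeros(1), torch.eye(1), F, Q)\n",
    "    # correct with the measurement z = 4\n",
    "    return kf_correct(x_hat, P_hat, torch.tensor([4.0]), H, R)\n",
    "\n",
    "\n",
    "x_demo, P_demo = demo_scalar_step()\n",
    "print(x_demo, P_demo)"
   ]
  },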
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Nonlinear Estimation: Extended Kalman Filter\n",
    "\n",
    "What if our system is non-linear, e.g., in GPS navigation?  Consider the following non-linear system:\n",
    "\n",
    "$$ X_{k+1} = \\mathbf{f}(X_k) + \\mathbf{W}_k $$\n",
    "$$ \\mathbf{Z}_k = \\mathbf{h}(X_k) + \\mathbf{V}_k $$\n",
    "\n",
    "Notice that $\\mathbf{f}$ and $\\mathbf{h}$ are now (smooth) non-linear functions.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Extended Kalman Filter (EKF) attacks this problem by locally linearizing the non-linear functions with a first-order [Taylor series expansion](https://en.wikipedia.org/wiki/Taylor_series) around the current estimate.\n",
    "\n",
    "$$ f(X_k, k) \\approx f(x_k^R, k) + \\mathbf{H}_k(X_k - x_k^R) + \\cdots$$\n",
    "\n",
    "where $\\mathbf{H}_k$ is the Jacobian matrix at time $k$, $x_k^R$ is the previous optimal estimate, and we ignore the higher order terms.  At each time step, we compute a Jacobian conditioned on the previous predictions (this computation is handled by Pyro under the hood), and use the result to perform a prediction and update.\n",
    "\n",
    "Omitting the derivations, the modified predictions are now:\n",
    "$$ \\hat{X}_k \\approx \\mathbf{f}(X_{k-1}^R)$$\n",
    "$$ \\hat{P}_k = \\mathbf{H}_\\mathbf{f}(X_{k-1})P_{k-1}\\mathbf{H}_\\mathbf{f}^T(X_{k-1}) + Q$$\n",
    "and the updates are now:\n",
    "$$ X_k \\approx \\hat{X}_k + K_k\\big(z_k - \\mathbf{h}(\\hat{X}_k)\\big)$$\n",
    "$$ K_k = \\hat{P}_k \\mathbf{H}_\\mathbf{h}^T(\\hat{X}_k) \\Big(\\mathbf{H}_\\mathbf{h}(\\hat{X}_k)\\hat{P}_k \\mathbf{H}_\\mathbf{h}^T(\\hat{X}_k) + R_k\\Big)^{-1} $$\n",
    "$$ P_k = \\big(I - K_k \\mathbf{H}_\\mathbf{h}(\\hat{X}_k)\\big)\\hat{P}_k$$\n",
    "\n",
    "In Pyro, all we need to do is create an `EKFState` object and use its `predict` and `update` methods. Pyro will do exact inference to compute the innovations and we will use SVI to learn a MAP estimate of the position and measurement covariances.\n",
    "\n",
    "As an example, let's look at an object moving at near-constant velocity in 2-D in a discrete time space over 10 time steps."
   ]
  },
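  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pyro computes the Jacobians internally, but the linearization step itself is easy to illustrate. The sketch below builds the Jacobian of a non-linear measurement function one row at a time with `torch.autograd.grad` and compares the first-order approximation with the exact value at a nearby point. The range-bearing function `h`, the helper `jacobian`, and the state values are assumptions made for this example; they are not part of `pyro.contrib.tracking`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def h(x):\n",
    "    # hypothetical non-linear measurement: range and bearing of a 2-D position\n",
    "    rng = torch.sqrt(x[0] ** 2 + x[1] ** 2)\n",
    "    bearing = torch.atan2(x[1], x[0])\n",
    "    return torch.stack([rng, bearing])\n",
    "\n",
    "\n",
    "def jacobian(func, x):\n",
    "    # differentiate each output component with respect to the input\n",
    "    x = x.detach().clone().requires_grad_(True)\n",
    "    y = func(x)\n",
    "    rows = [torch.autograd.grad(y[i], x, retain_graph=True)[0]\n",
    "            for i in range(y.shape[0])]\n",
    "    return torch.stack(rows)\n",
    "\n",
    "\n",
    "# linearize around a reference estimate\n",
    "x_ref = torch.tensor([3.0, 4.0])\n",
    "H_h = jacobian(h, x_ref)\n",
    "\n",
    "# first-order approximation of h at a nearby point\n",
    "delta = torch.tensor([1e-3, -1e-3])\n",
    "linearized = h(x_ref) + H_h @ delta\n",
    "\n",
    "print(H_h)\n",
    "# difference between the exact value and its linearization\n",
    "print(h(x_ref + delta) - linearized)"
   ]
  },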
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from pyro.contrib.autoguide import AutoDelta\n",
    "from pyro.optim import Adam\n",
    "from pyro.infer import SVI, Trace_ELBO, config_enumerate\n",
    "from pyro.contrib.tracking.extended_kalman_filter import EKFState\n",
    "from pyro.contrib.tracking.distributions import EKFDistribution\n",
    "from pyro.contrib.tracking.dynamic_models import NcvContinuous\n",
    "from pyro.contrib.tracking.measurements import PositionMeasurement\n",
    "\n",
    "smoke_test = ('CI' in os.environ)\n",
    "assert pyro.__version__.startswith('0.3.0')\n",
    "pyro.enable_validation(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dt = 1e-2\n",
    "num_frames = 10\n",
    "dim = 4\n",
    "\n",
    "# Continuous model\n",
    "ncv = NcvContinuous(dim, 2.0)\n",
    "\n",
    "# Truth trajectory\n",
    "xs_truth = torch.zeros(num_frames, dim)\n",
    "# initial direction\n",
    "theta0_truth = 0.0\n",
    "# initial state\n",
    "with torch.no_grad():\n",
    "    xs_truth[0, :] = torch.tensor([0.0, 0.0,  math.cos(theta0_truth), math.sin(theta0_truth)])\n",
    "    for frame_num in range(1, num_frames):\n",
    "        # sample independent process noise\n",
    "        dx = pyro.sample('process_noise_{}'.format(frame_num), ncv.process_noise_dist(dt))\n",
    "        xs_truth[frame_num, :] = ncv(xs_truth[frame_num-1, :], dt=dt) + dx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's specify the measurements. Notice that we only measure the positions of the particle."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Measurements\n",
    "measurements = []\n",
    "mean = torch.zeros(2)\n",
    "# no correlations\n",
    "cov = 1e-5 * torch.eye(2)\n",
    "with torch.no_grad():\n",
    "    # sample independent measurement noise\n",
    "    dzs = pyro.sample('dzs', dist.MultivariateNormal(mean, cov).expand((num_frames,)))\n",
    "    # compute measurement means\n",
    "    zs = xs_truth[:, :2] + dzs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll use a [Delta autoguide](http://docs.pyro.ai/en/dev/contrib.autoguide.html#autodelta) to learn MAP estimates of the position and measurement covariances. The `EKFDistribution` computes the joint log density of all of the EKF states given a tensor of sequential measurements."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(data):\n",
    "    # a HalfNormal can be used here as well\n",
    "    # R scales the 4x4 position-velocity covariance passed to the EKF\n",
    "    R = pyro.sample('pv_cov', dist.HalfCauchy(2e-6)) * torch.eye(4)\n",
    "    # Q scales the 2x2 measurement covariance\n",
    "    Q = pyro.sample('measurement_cov', dist.HalfCauchy(1e-6)) * torch.eye(2)\n",
    "    # observe the measurements; a fixed site name avoids relying on a global loop index\n",
    "    pyro.sample('track', EKFDistribution(xs_truth[0], R, ncv,\n",
    "                                         Q, time_steps=num_frames),\n",
    "                obs=data)\n",
    "\n",
    "guide = AutoDelta(model)  # MAP estimation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optim = pyro.optim.Adam({'lr': 2e-2})\n",
    "svi = SVI(model, guide, optim, loss=Trace_ELBO(retain_graph=True))\n",
    "\n",
    "pyro.set_rng_seed(0)\n",
    "pyro.clear_param_store()\n",
    "\n",
    "for i in range(250 if not smoke_test else 2):\n",
    "    loss = svi.step(zs)\n",
    "    if not i % 10:\n",
    "        print('loss: ', loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# retrieve states for visualization\n",
    "R = guide()['pv_cov'] * torch.eye(4)\n",
    "Q = guide()['measurement_cov'] * torch.eye(2)\n",
    "ekf_dist = EKFDistribution(xs_truth[0], R, ncv, Q, time_steps=num_frames)\n",
    "states = ekf_dist.filter_states(zs)"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/html"
   },
   "source": [
    "<center><figure>\n",
    "    <table>\n",
    "        <tr>\n",
    "            <td> \n",
    "                <img src=\"_static/img/ekf_track.png\" style=\"width: 450px;\">\n",
    "            </td>\n",
    "        </tr>\n",
    "    </table> <center>\n",
    "    <figcaption> \n",
    "        <font size=\"+1\"><b>Figure 1:</b> True track and EKF prediction with error. </font> \n",
    "    </figcaption> </center>\n",
    "</figure></center>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
